/* Used by the kernel as the half-width of its search window.
 * NOTE(review): the kernel reads every patch as eight float8 rows regardless
 * of this value - presumably the two are meant to stay in sync; confirm
 * before changing it. */
#define patchSide 8

/* Work-group dimensions the kernel is compiled for (reqd_work_group_size);
 * they also size the kernel's per-work-group scratch array. */
#define WGS_W 8
#define WGS_H 8

/* N is the number of patch offsets written per reference patch. NSHIFT is
 * the shift the kernel uses to turn a reference-patch index into the start
 * of its N-entry output slice, so it has to be log2(N). */
#define N 16
#define NSHIFT 4

#if (1 << NSHIFT) != N
#error "NSHIFT must equal log2(N): output slices are addressed with << NSHIFT"
#endif

/*
 * patches - for each reference patch, gather the offsets of the most similar
 * patches found in a search window around it.
 *
 * Work-item (ind_i, ind_j) handles the reference patch whose top-left pixel
 * is (w_ind[ind_i], h_ind[ind_j]); work-items outside
 * w_ind_size x h_ind_size do nothing. A patch is eight rows of eight floats
 * read from img, which is addressed as img[row*width + column].
 *
 * Candidates are the patches whose top-left corner is displaced by
 * [-patchSide, patchSide) pixels in each direction. The window is clipped so
 * that every candidate lies fully inside the width x height image. A
 * candidate is kept when its sum of squared differences to the reference is
 * below threshold. Kept candidates are stored in ascending order of that
 * distance; slot 0 always holds the reference itself, so at most N-1
 * candidates survive.
 *
 * Parameters:
 *   w_ind, h_ind   column / row of the top-left pixel of each reference patch
 *   w_ind_size     number of entries of w_ind that are processed
 *   h_ind_size     number of entries of h_ind that are processed
 *   img            pixel data, row stride = width
 *   width, height  image dimensions in pixels
 *   threshold      exclusive upper bound on the distance of a kept candidate
 *   offsets        output; N entries per reference patch, starting at
 *                  (ind_j*w_ind_size + ind_i) << NSHIFT. Each entry is the
 *                  img offset of a patch's top-left pixel; slots that never
 *                  receive a match repeat the reference patch's offset.
 *
 * The reference patch itself is read without a bounds check: callers must
 * supply indices with i + 8 <= width and j + 8 <= height.
 *
 * NOTE(review): the zero displacement is part of the search window, so when
 * threshold > 0 the reference patch is also inserted as a candidate
 * (distance 0, slot 1). Confirm whether consumers expect that duplicate.
 */
__kernel __attribute__((reqd_work_group_size(WGS_W, WGS_H, 1)))
void patches(
             __global unsigned*   w_ind,
             __global unsigned*   h_ind,
             int                  w_ind_size,
             int                  h_ind_size,
             __global float*      img,
             int                  width,
             int                  height,
             float                threshold,
             __global unsigned*   offsets
             )
{
    const int ind_i = get_global_id(0);
    const int ind_j = get_global_id(1);

    if(ind_i >= w_ind_size || ind_j >= h_ind_size)
        return;

    // Top-left pixel of the reference patch.
    const int i = w_ind[ind_i];
    const int j = h_ind[ind_j];

    const int local_i = get_local_id(0);
    const int local_j = get_local_id(1);
    const int local_idx = local_j*WGS_W + local_i;

    typedef struct dist_t { unsigned offset; float dist; } dist_t;

    // One N-entry candidate list per work-item, held in local memory. Each
    // work-item only touches its own row, so no barrier is needed.
    __local dist_t buf[WGS_W*WGS_H][N];
    __local dist_t* best = buf[local_idx];

    // Slot 0 is the reference patch. The other slots start out as copies of
    // it, so slots that never receive a match still hold the reference offset.
    const int offsetOrg = j*width + i;
#pragma unroll
    for(int k = 0; k < N; k++)
    {
        best[k].offset = offsetOrg;
        best[k].dist = 0;
    }
    int count = 1; // number of slots holding real entries (reference included)

    // Reference patch: eight rows of eight pixels.
    const float8 p1 = vload8(0, img + offsetOrg + 0*width);
    const float8 p2 = vload8(0, img + offsetOrg + 1*width);
    const float8 p3 = vload8(0, img + offsetOrg + 2*width);
    const float8 p4 = vload8(0, img + offsetOrg + 3*width);
    const float8 p5 = vload8(0, img + offsetOrg + 4*width);
    const float8 p6 = vload8(0, img + offsetOrg + 5*width);
    const float8 p7 = vload8(0, img + offsetOrg + 6*width);
    const float8 p8 = vload8(0, img + offsetOrg + 7*width);

    // Rows and columns actually read per patch (eight float8 loads above and
    // below); used to keep candidate patches inside the image.
    const int patchExtent = 8;

    // Clip the search window so that a candidate never starts before row or
    // column 0 and never extends past the last row or column. Without this,
    // patches near the border would be read from outside the image or would
    // wrap around into neighbouring rows.
    const int halfSearchWindow = patchSide;
    const int djBegin = max(-halfSearchWindow, -j);
    const int djEnd   = min(halfSearchWindow, height - patchExtent - j + 1);
    const int diBegin = max(-halfSearchWindow, -i);
    const int diEnd   = min(halfSearchWindow, width - patchExtent - i + 1);

    for(int dj = djBegin; dj < djEnd; dj++)
    {
        const int y = j + dj;
        const int rowBase = y*width;

        for(int di = diBegin; di < diEnd; di++)
        {
            const int x = i + di;
            const int offsetCur = rowBase + x;

            // Per-row differences between candidate and reference.
            const float8 dp1 = vload8(0, img + offsetCur + 0*width) - p1;
            const float8 dp2 = vload8(0, img + offsetCur + 1*width) - p2;
            const float8 dp3 = vload8(0, img + offsetCur + 2*width) - p3;
            const float8 dp4 = vload8(0, img + offsetCur + 3*width) - p4;
            const float8 dp5 = vload8(0, img + offsetCur + 4*width) - p5;
            const float8 dp6 = vload8(0, img + offsetCur + 5*width) - p6;
            const float8 dp7 = vload8(0, img + offsetCur + 6*width) - p7;
            const float8 dp8 = vload8(0, img + offsetCur + 7*width) - p8;

            // Sum of squared differences over the whole patch.
            const float dist =
                dot(dp1.lo, dp1.lo) + dot(dp1.hi, dp1.hi) +
                dot(dp2.lo, dp2.lo) + dot(dp2.hi, dp2.hi) +
                dot(dp3.lo, dp3.lo) + dot(dp3.hi, dp3.hi) +
                dot(dp4.lo, dp4.lo) + dot(dp4.hi, dp4.hi) +
                dot(dp5.lo, dp5.lo) + dot(dp5.hi, dp5.hi) +
                dot(dp6.lo, dp6.lo) + dot(dp6.hi, dp6.hi) +
                dot(dp7.lo, dp7.lo) + dot(dp7.hi, dp7.hi) +
                dot(dp8.lo, dp8.lo) + dot(dp8.hi, dp8.hi);

            if(!(dist < threshold))
                continue;

            // Binary search in slots [1, count) for the first entry whose
            // distance is not smaller than dist. Slot 0 is excluded so the
            // reference patch always stays first; a candidate that ties with
            // existing entries is placed in front of them.
            int lo = 1;
            int hi = count;
            while(lo < hi)
            {
                const int mid = (lo + hi)>>1;
                if(best[mid].dist < dist)
                    lo = mid + 1;
                else
                    hi = mid;
            }

            const int insertPos = lo;
            if(insertPos < N)
            {
                // Grow the list until it is full; once full, the shift below
                // drops the current worst entry.
                count += (count < N);
                for(int k = count-1; k > insertPos; k--)
                    best[k] = best[k-1];
                best[insertPos].offset = offsetCur;
                best[insertPos].dist = dist;
            }
        }
    }

    // Copy this work-item's list of offsets to its slice of the output.
    const int offs = (ind_j*w_ind_size + ind_i)<<(NSHIFT);
    __global unsigned* g_offsets = offsets + offs;
#pragma unroll
    for(int k = 0; k < N; k++)
        g_offsets[k] = best[k].offset;
}
